# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
Endpoint generation logic for FastAPI OpenAPI generation.
"""

import inspect
import re
import types
import typing
from typing import Annotated, Any, get_args, get_origin

from fastapi import FastAPI
from fastapi.params import Body as FastAPIBody
from pydantic import Field, create_model

from llama_stack.log import get_logger
from llama_stack_api import Api
from llama_stack_api.schema_utils import get_registered_schema_info

from . import app as app_module
from .state import _extra_body_fields, register_dynamic_model

logger = get_logger(name=__name__, category="core")

type QueryParameter = tuple[str, type, Any, bool]


def _to_pascal_case(segment: str) -> str:
    tokens = re.findall(r"[A-Za-z]+|\d+", segment)
    return "".join(token.capitalize() for token in tokens if token)


def _compose_request_model_name(api: Api, method_name: str, variant: str | None = None) -> str:
    """Generate a deterministic model name from the protocol method."""

    def _to_pascal_from_snake(value: str) -> str:
        return "".join(segment.capitalize() for segment in value.split("_") if segment)

    base_name = _to_pascal_from_snake(method_name)
    if not base_name:
        base_name = _to_pascal_case(api.value)
    base_name = f"{base_name}Request"
    if variant:
        base_name = f"{base_name}{variant}"
    return base_name


def _extract_path_parameters(path: str) -> list[dict[str, Any]]:
    """Extract path parameters from a URL path and return them as OpenAPI parameter definitions."""
    matches = re.findall(r"\{([^}:]+)(?::[^}]+)?\}", path)
    return [
        {
            "name": param_name,
            "in": "path",
            "required": True,
            "schema": {"type": "string"},
            "description": f"Path parameter: {param_name}",
        }
        for param_name in matches
    ]


def _create_endpoint_with_request_model(
    request_model: type, response_model: type | None, operation_description: str | None
):
    """Create an endpoint function with a request body model."""

    async def endpoint(request: request_model) -> response_model:
        return response_model() if response_model else {}

    if operation_description:
        endpoint.__doc__ = operation_description
    return endpoint


def _build_field_definitions(query_parameters: list[QueryParameter], use_any: bool = False) -> dict[str, tuple]:
    """Build field definitions for a Pydantic model from query parameters."""
    from typing import Any

    field_definitions = {}
    for param_name, param_type, default_value, _ in query_parameters:
        if use_any:
            field_definitions[param_name] = (Any, ... if default_value is inspect.Parameter.empty else default_value)
            continue

        base_type = param_type
        extracted_field = None
        if get_origin(param_type) is Annotated:
            args = get_args(param_type)
            if args:
                base_type = args[0]
                for arg in args[1:]:
                    if isinstance(arg, Field):
                        extracted_field = arg
                        break

        try:
            if extracted_field:
                field_definitions[param_name] = (base_type, extracted_field)
            else:
                field_definitions[param_name] = (
                    base_type,
                    ... if default_value is inspect.Parameter.empty else default_value,
                )
        except (TypeError, ValueError):
            field_definitions[param_name] = (Any, ... if default_value is inspect.Parameter.empty else default_value)

    # Ensure all parameters are included
    expected_params = {name for name, _, _, _ in query_parameters}
    missing = expected_params - set(field_definitions.keys())
    if missing:
        for param_name, _, default_value, _ in query_parameters:
            if param_name in missing:
                field_definitions[param_name] = (
                    Any,
                    ... if default_value is inspect.Parameter.empty else default_value,
                )

    return field_definitions


def _create_dynamic_request_model(
    api: Api,
    webmethod,
    method_name: str,
    http_method: str,
    query_parameters: list[QueryParameter],
    use_any: bool = False,
    variant_suffix: str | None = None,
) -> type | None:
    """Create a dynamic Pydantic model for request body."""
    try:
        field_definitions = _build_field_definitions(query_parameters, use_any)
        if not field_definitions:
            return None
        model_name = _compose_request_model_name(api, method_name, variant_suffix or None)
        request_model = create_model(model_name, **field_definitions)
        return register_dynamic_model(model_name, request_model)
    except Exception:
        return None


def _build_signature_params(
    query_parameters: list[QueryParameter],
) -> tuple[list[inspect.Parameter], dict[str, type]]:
    """Build signature parameters and annotations from query parameters."""
    signature_params = []
    param_annotations = {}
    for param_name, param_type, default_value, _ in query_parameters:
        param_annotations[param_name] = param_type
        signature_params.append(
            inspect.Parameter(
                param_name,
                inspect.Parameter.POSITIONAL_OR_KEYWORD,
                default=default_value if default_value is not inspect.Parameter.empty else inspect.Parameter.empty,
                annotation=param_type,
            )
        )
    return signature_params, param_annotations


def _extract_operation_description_from_docstring(api: Api, method_name: str) -> str | None:
    """Extract operation description from the actual function docstring."""
    func = app_module._get_protocol_method(api, method_name)
    if not func or not func.__doc__:
        return None

    doc_lines = func.__doc__.split("\n")
    description_lines = []
    metadata_markers = (":param", ":type", ":return", ":returns", ":raises", ":exception", ":yield", ":yields", ":cvar")

    for line in doc_lines:
        if line.strip().startswith(metadata_markers):
            break
        description_lines.append(line)

    description = "\n".join(description_lines).strip()
    return description if description else None


def _extract_response_description_from_docstring(webmethod, response_model, api: Api, method_name: str) -> str:
    """Extract response description from the actual function docstring."""
    func = app_module._get_protocol_method(api, method_name)
    if not func or not func.__doc__:
        return "Successful Response"
    for line in func.__doc__.split("\n"):
        if line.strip().startswith(":returns:"):
            if desc := line.strip()[9:].strip():
                return desc
    return "Successful Response"


def _get_tag_from_api(api: Api) -> str:
    """Extract a tag name from the API enum for API grouping."""
    return api.value.replace("_", " ").title()


def _is_file_or_form_param(param_type: Any) -> bool:
    """Check if a parameter type is annotated with File() or Form()."""
    if get_origin(param_type) is Annotated:
        args = get_args(param_type)
        if len(args) > 1:
            # Check metadata for File or Form
            for metadata in args[1:]:
                # Check if it's a File or Form instance
                if hasattr(metadata, "__class__"):
                    class_name = metadata.__class__.__name__
                    if class_name in ("File", "Form"):
                        return True
    return False


def _is_extra_body_field(metadata_item: Any) -> bool:
    """Check if a metadata item is an ExtraBodyField instance."""
    from llama_stack_api.schema_utils import ExtraBodyField

    return isinstance(metadata_item, ExtraBodyField)


def _should_embed_parameter(param_type: Any) -> bool:
    """Determine whether a parameter should be embedded (wrapped) in the request body."""
    if get_origin(param_type) is Annotated:
        args = get_args(param_type)
        metadata = args[1:] if len(args) > 1 else []
        for metadata_item in metadata:
            if isinstance(metadata_item, FastAPIBody):
                # FastAPI treats embed=None as False, so default to False when unset.
                return bool(metadata_item.embed)
    # Unannotated parameters default to embed=True through create_dynamic_typed_route.
    return True


def _is_async_iterator_type(type_obj: Any) -> bool:
    """Check if a type is AsyncIterator or AsyncIterable."""
    from collections.abc import AsyncIterable, AsyncIterator

    origin = get_origin(type_obj)
    if origin is None:
        # Check if it's the class itself
        return type_obj in (AsyncIterator, AsyncIterable) or (
            hasattr(type_obj, "__origin__") and type_obj.__origin__ in (AsyncIterator, AsyncIterable)
        )
    return origin in (AsyncIterator, AsyncIterable)


def _extract_response_models_from_union(union_type: Any) -> tuple[type | None, type | None]:
    """
    Extract non-streaming and streaming response models from a union type.

    Returns:
        tuple: (non_streaming_model, streaming_model)
    """
    non_streaming_model = None
    streaming_model = None

    args = get_args(union_type)
    for arg in args:
        # Check if it's an AsyncIterator
        if _is_async_iterator_type(arg):
            # Extract the type argument from AsyncIterator[T]
            iterator_args = get_args(arg)
            if iterator_args:
                inner_type = iterator_args[0]
                # Check if the inner type is a registered schema (union type)
                # or a Pydantic model
                if hasattr(inner_type, "model_json_schema"):
                    streaming_model = inner_type
                else:
                    # Might be a registered schema - check if it's registered
                    if get_registered_schema_info(inner_type):
                        # We'll need to look this up later, but for now store the type
                        streaming_model = inner_type
        elif hasattr(arg, "model_json_schema"):
            # Non-streaming Pydantic model
            if non_streaming_model is None:
                non_streaming_model = arg

    return non_streaming_model, streaming_model


def _find_models_for_endpoint(
    webmethod, api: Api, method_name: str, is_post_put: bool = False
) -> tuple[type | None, type | None, list[tuple[str, type, Any]], list[inspect.Parameter], type | None, str | None]:
    """
    Find appropriate request and response models for an endpoint by analyzing the actual function signature.
    This uses the protocol function to determine the correct models dynamically.

    Args:
        webmethod: The webmethod metadata
        api: The API enum for looking up the function
        method_name: The method name (function name)
        is_post_put: Whether this is a POST, PUT, or PATCH request (GET requests should never have request bodies)

    Returns:
        tuple: (request_model, response_model, query_parameters, file_form_params, streaming_response_model, response_schema_name)
        where query_parameters is a list of (name, type, default_value, should_embed) tuples
        and file_form_params is a list of inspect.Parameter objects for File()/Form() params
        and streaming_response_model is the model for streaming responses (AsyncIterator content)
    """
    route_descriptor = f"{webmethod.method or 'UNKNOWN'} {webmethod.route}"
    try:
        # Get the function from the protocol
        func = app_module._get_protocol_method(api, method_name)
        if not func:
            logger.warning("No protocol method for %s.%s (%s)", api, method_name, route_descriptor)
            return None, None, [], [], None, None

        # Analyze the function signature
        sig = inspect.signature(func)

        # Find request model and collect all body parameters
        request_model = None
        query_parameters: list[QueryParameter] = []
        file_form_params = []
        path_params = set()
        extra_body_params = []
        response_schema_name = None

        # Extract path parameters from the route
        if webmethod and hasattr(webmethod, "route"):
            path_matches = re.findall(r"\{([^}:]+)(?::[^}]+)?\}", webmethod.route)
            path_params = set(path_matches)

        for param_name, param in sig.parameters.items():
            if param_name == "self":
                continue

            # Skip *args and **kwargs parameters - these are not real API parameters
            if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                continue

            # Check if this is a path parameter
            if param_name in path_params:
                # Path parameters are handled separately, skip them
                continue

            # Check if it's a File() or Form() parameter - these need special handling
            param_type = param.annotation
            param_should_embed = _should_embed_parameter(param_type)
            if _is_file_or_form_param(param_type):
                # File() and Form() parameters must be in the function signature directly
                # They cannot be part of a Pydantic model
                file_form_params.append(param)
                continue

            # Check for ExtraBodyField in Annotated types
            is_extra_body = False
            extra_body_description = None
            if get_origin(param_type) is Annotated:
                args = get_args(param_type)
                base_type = args[0] if args else param_type
                metadata = args[1:] if len(args) > 1 else []

                # Check if any metadata item is an ExtraBodyField
                for metadata_item in metadata:
                    if _is_extra_body_field(metadata_item):
                        is_extra_body = True
                        extra_body_description = metadata_item.description
                        break

                if is_extra_body:
                    # Store as extra body parameter - exclude from request model
                    extra_body_params.append((param_name, base_type, extra_body_description))
                    continue
                param_type = base_type

            # Check if it's a Pydantic model (for POST/PUT requests)
            if hasattr(param_type, "model_json_schema"):
                query_parameters.append((param_name, param_type, param.default, param_should_embed))
            else:
                # Regular annotated parameter (but not File/Form, already handled above)
                query_parameters.append((param_name, param_type, param.default, param_should_embed))

        # Store extra body fields for later use in post-processing
        # We'll store them when the endpoint is created, as we need the full path
        # For now, attach to the function for later retrieval
        if extra_body_params:
            func._extra_body_params = extra_body_params  # type: ignore

        # If there's exactly one body parameter and it's a Pydantic model, use it directly
        # Otherwise, we'll create a combined request model from all parameters
        # BUT: For GET requests, never create a request body - all parameters should be query parameters
        if is_post_put and len(query_parameters) == 1:
            param_name, param_type, default_value, should_embed = query_parameters[0]
            if hasattr(param_type, "model_json_schema") and not should_embed:
                request_model = param_type
                query_parameters = []  # Clear query_parameters so we use the single model

        # Find response model from return annotation
        # Also detect streaming response models (AsyncIterator)
        response_model = None
        streaming_response_model = None
        return_annotation = sig.return_annotation
        if return_annotation != inspect.Signature.empty:
            origin = get_origin(return_annotation)
            if hasattr(return_annotation, "model_json_schema"):
                response_model = return_annotation
            elif origin is Annotated:
                # Handle Annotated return types
                args = get_args(return_annotation)
                if args:
                    # Check if the first argument is a Pydantic model
                    if hasattr(args[0], "model_json_schema"):
                        response_model = args[0]
                    else:
                        # Check if the first argument is a union type
                        inner_origin = get_origin(args[0])
                        if inner_origin is not None and (
                            inner_origin is types.UnionType or inner_origin is typing.Union
                        ):
                            response_model, streaming_response_model = _extract_response_models_from_union(args[0])
            elif origin is not None and (origin is types.UnionType or origin is typing.Union):
                # Handle union types - extract both non-streaming and streaming models
                response_model, streaming_response_model = _extract_response_models_from_union(return_annotation)
            else:
                try:
                    from fastapi import Response as FastAPIResponse
                except ImportError:
                    fastapi_response_cls = None
                else:
                    fastapi_response_cls = FastAPIResponse
                try:
                    from starlette.responses import Response as StarletteResponse
                except ImportError:
                    starlette_response_cls = None
                else:
                    starlette_response_cls = StarletteResponse

                response_types = tuple(t for t in (fastapi_response_cls, starlette_response_cls) if t is not None)
                if response_types and any(return_annotation is t for t in response_types):
                    response_schema_name = "Response"

        return (
            request_model,
            response_model,
            query_parameters,
            file_form_params,
            streaming_response_model,
            response_schema_name,
        )

    except Exception as exc:
        logger.warning(
            "Failed to analyze endpoint %s.%s (%s): %s", api, method_name, route_descriptor, exc, exc_info=True
        )
        return None, None, [], [], None, None


def _create_fastapi_endpoint(app: FastAPI, route, webmethod, api: Api):
    """Create a FastAPI endpoint from a discovered route and webmethod."""
    path = route.path
    raw_methods = route.methods or set()
    method_list = sorted({method.upper() for method in raw_methods if method and method.upper() != "HEAD"})
    if not method_list:
        method_list = ["GET"]
    primary_method = method_list[0]
    name = route.name
    fastapi_path = path.replace("{", "{").replace("}", "}")
    is_post_put = any(method in ["POST", "PUT", "PATCH"] for method in method_list)

    (
        request_model,
        response_model,
        query_parameters,
        file_form_params,
        streaming_response_model,
        response_schema_name,
    ) = _find_models_for_endpoint(webmethod, api, name, is_post_put)
    operation_description = _extract_operation_description_from_docstring(api, name)
    response_description = _extract_response_description_from_docstring(webmethod, response_model, api, name)

    # Retrieve and store extra body fields for this endpoint
    func = app_module._get_protocol_method(api, name)
    extra_body_params = getattr(func, "_extra_body_params", []) if func else []
    if extra_body_params:
        for method in method_list:
            key = (fastapi_path, method.upper())
            _extra_body_fields[key] = extra_body_params

    if is_post_put and not request_model and not file_form_params and query_parameters:
        request_model = _create_dynamic_request_model(
            api, webmethod, name, primary_method, query_parameters, use_any=False
        )
        if not request_model:
            request_model = _create_dynamic_request_model(
                api, webmethod, name, primary_method, query_parameters, use_any=True, variant_suffix="Loose"
            )
        if request_model:
            query_parameters = []

    if file_form_params and is_post_put:
        signature_params = list(file_form_params)
        param_annotations = {param.name: param.annotation for param in file_form_params}
        for param_name, param_type, default_value, _ in query_parameters:
            signature_params.append(
                inspect.Parameter(
                    param_name,
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    default=default_value if default_value is not inspect.Parameter.empty else inspect.Parameter.empty,
                    annotation=param_type,
                )
            )
            param_annotations[param_name] = param_type

        async def file_form_endpoint():
            return response_model() if response_model else {}

        if operation_description:
            file_form_endpoint.__doc__ = operation_description
        file_form_endpoint.__signature__ = inspect.Signature(signature_params)
        file_form_endpoint.__annotations__ = param_annotations
        endpoint_func = file_form_endpoint
    elif request_model and response_model:
        endpoint_func = _create_endpoint_with_request_model(request_model, response_model, operation_description)
    elif request_model:
        endpoint_func = _create_endpoint_with_request_model(request_model, None, operation_description)
    elif response_model and query_parameters:
        if is_post_put:
            request_model = _create_dynamic_request_model(
                api, webmethod, name, primary_method, query_parameters, use_any=False
            )
            if not request_model:
                request_model = _create_dynamic_request_model(
                    api, webmethod, name, primary_method, query_parameters, use_any=True, variant_suffix="Loose"
                )

            if request_model:
                endpoint_func = _create_endpoint_with_request_model(
                    request_model, response_model, operation_description
                )
            else:

                async def empty_endpoint() -> response_model:
                    return response_model() if response_model else {}

                if operation_description:
                    empty_endpoint.__doc__ = operation_description
                endpoint_func = empty_endpoint
        else:
            sorted_params = sorted(query_parameters, key=lambda x: (x[2] is not inspect.Parameter.empty, x[0]))
            signature_params, param_annotations = _build_signature_params(sorted_params)

            async def query_endpoint():
                return response_model()

            if operation_description:
                query_endpoint.__doc__ = operation_description
            query_endpoint.__signature__ = inspect.Signature(signature_params)
            query_endpoint.__annotations__ = param_annotations
            endpoint_func = query_endpoint
    elif response_model:

        async def response_only_endpoint() -> response_model:
            return response_model()

        if operation_description:
            response_only_endpoint.__doc__ = operation_description
        endpoint_func = response_only_endpoint
    elif query_parameters:
        signature_params, param_annotations = _build_signature_params(query_parameters)

        async def params_only_endpoint():
            return {}

        if operation_description:
            params_only_endpoint.__doc__ = operation_description
        params_only_endpoint.__signature__ = inspect.Signature(signature_params)
        params_only_endpoint.__annotations__ = param_annotations
        endpoint_func = params_only_endpoint
    else:
        # Endpoint with no parameters and no response model
        # If we have a response_model from the function signature, use it even if _find_models_for_endpoint didn't find it
        # This can happen if there was an exception during model finding
        if response_model is None:
            # Try to get response model directly from the function signature as a fallback
            func = app_module._get_protocol_method(api, name)
            if func:
                try:
                    sig = inspect.signature(func)
                    return_annotation = sig.return_annotation
                    if return_annotation != inspect.Signature.empty:
                        if hasattr(return_annotation, "model_json_schema"):
                            response_model = return_annotation
                        elif get_origin(return_annotation) is Annotated:
                            args = get_args(return_annotation)
                            if args and hasattr(args[0], "model_json_schema"):
                                response_model = args[0]
                except Exception:
                    pass

        if response_model:

            async def no_params_endpoint() -> response_model:
                return response_model() if response_model else {}
        else:

            async def no_params_endpoint():
                return {}

        if operation_description:
            no_params_endpoint.__doc__ = operation_description
        endpoint_func = no_params_endpoint

    # Build response content with both application/json and text/event-stream if streaming
    response_content: dict[str, Any] = {}
    if response_model:
        response_content["application/json"] = {"schema": {"$ref": f"#/components/schemas/{response_model.__name__}"}}
    elif response_schema_name:
        response_content["application/json"] = {"schema": {"$ref": f"#/components/schemas/{response_schema_name}"}}
    if streaming_response_model:
        # Get the schema name for the streaming model
        # It might be a registered schema or a Pydantic model
        streaming_schema_name = None
        # Check if it's a registered schema first (before checking __name__)
        # because registered schemas might be Annotated types
        if schema_info := get_registered_schema_info(streaming_response_model):
            streaming_schema_name = schema_info.name
        elif hasattr(streaming_response_model, "__name__"):
            streaming_schema_name = streaming_response_model.__name__

        if streaming_schema_name:
            response_content["text/event-stream"] = {
                "schema": {"$ref": f"#/components/schemas/{streaming_schema_name}"}
            }

    # If no content types, use empty schema
    # Add the endpoint to the FastAPI app
    is_deprecated = webmethod.deprecated or False
    route_kwargs = {
        "name": name,
        "tags": [_get_tag_from_api(api)],
        "deprecated": is_deprecated,
        "responses": {
            400: {"$ref": "#/components/responses/BadRequest400"},
            429: {"$ref": "#/components/responses/TooManyRequests429"},
            500: {"$ref": "#/components/responses/InternalServerError500"},
            "default": {"$ref": "#/components/responses/DefaultError"},
        },
    }
    success_response: dict[str, Any] = {"description": response_description}
    if response_content:
        success_response["content"] = response_content
    route_kwargs["responses"][200] = success_response

    # FastAPI needs response_model parameter to properly generate OpenAPI spec
    # Use the non-streaming response model if available
    if response_model:
        route_kwargs["response_model"] = response_model

    method_map = {"GET": app.get, "POST": app.post, "PUT": app.put, "DELETE": app.delete, "PATCH": app.patch}
    for method in method_list:
        if handler := method_map.get(method):
            handler(fastapi_path, **route_kwargs)(endpoint_func)
